Add support for graph operations #355
Conversation
Currently I have only implemented:

```python
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])
cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 1, 2, 2]))
net = jx.Network([cell] * 3)

connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse())
connect(net[0, 0, 1], net[1, 0, 1], IonotropicSynapse())
connect(net[0, 0, 1], net[1, 0, 1], TestSynapse())

net.cell(2).add_to_group("cell2")
net.cell(2).branch(1).add_to_group("cell2branch1")

net.cell(0).insert(Na())
net.cell(0).insert(Leak())
net.cell(1).branch(1).insert(Na())
net.cell(0).insert(K())

net.compute_xyz()

net.cell(0).branch(0).loc(0.0).record()
net.cell(0).branch(0).loc(0.0).record("m")

current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
net.cell(0).branch(2).loc(0.0).stimulate(current)

net.cell(0).branch(1).make_trainable("Na")
net.cell(1).make_trainable("K")

net.compute_xyz()
net.cell(1).move(0, 30, 0)
net.cell(2).move(0, -30, 0)
```

We can then just call `module_graph = net.to_graph()`, plot the graph,

```python
# plot the graph
pos = {i: (n["x"], n["y"]) for i, n in module_graph.nodes(data=True)}
plt.figure(figsize=(8, 8))
nx.draw(module_graph, pos, with_labels=True, node_size=200, node_color="skyblue", font_size=8, font_weight="bold", font_color="black", font_family="sans-serif")
plt.show()
```

and look at all its properties, e.g. checking out the soma (0) and the synapse going from node 0 to node 20:

```python
print(module_graph.nodes[0])
print(module_graph.edges[(0, 20)])
```
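For readers unfamiliar with networkx: `nodes(data=True)` yields `(node, attr_dict)` pairs, which is what the `pos` comprehension above relies on. A toy illustration with made-up coordinates:

```python
import networkx as nx

# Tiny stand-in graph; the x/y attrs mimic what compute_xyz() would attach.
g = nx.Graph()
g.add_node(0, x=0.0, y=1.0)
g.add_node(1, x=2.0, y=3.0)

# nodes(data=True) yields (node_id, attr_dict) pairs
pos = {i: (n["x"], n["y"]) for i, n in g.nodes(data=True)}
print(pos)  # {0: (0.0, 1.0), 1: (2.0, 3.0)}
```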
Both
jaxley/modules/base.py
Outdated

```python
trainable_params = {i: {} for i in trainable_inds}
for i in trainable_inds:
    for inds, params in zip(
        self.indices_set_by_trainables, self.trainable_params
```
@michaeldeistler is there a scenario where there is more than one default value for an item in `Module.trainable_params`? I.e. the items are all dicts of the form `{"Na_gNa": np.array([0.1])}`, I think. Additionally, is there a case where there is more than one entry in the dict?
Ok, I just saw that you can have multiple defaults, so this code is wrong.
But you cannot have multiple keys, correct?
Indeed, you cannot have multiple keys.
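To make the structure under discussion concrete, here is a minimal sketch (the values and index arrays are made up; only the single-key shape is taken from the thread):

```python
import numpy as np

# Hypothetical mirror of the structures discussed above (values invented):
# Module.trainable_params        -> list of single-key dicts {param_name: default(s)}
# Module.indices_set_by_trainables -> list of index arrays, aligned with it.
trainable_params = [
    {"Na_gNa": np.array([0.1])},       # one shared default
    {"K_gK": np.array([0.02, 0.03])},  # multiple defaults (one per compartment)
]
indices_set_by_trainables = [
    np.array([0, 1, 2]),  # compartments trained via Na_gNa
    np.array([3, 4]),     # compartments trained via K_gK
]

# Each dict has exactly one key (as confirmed in the thread above),
# so single-element unpacking is safe:
names = []
for inds, params in zip(indices_set_by_trainables, trainable_params):
    (name,) = params.keys()
    names.append(name)
print(names)  # ['Na_gNa', 'K_gK']
```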
tutorials/dev.ipynb
Outdated
ignore this
I will do a more thorough review on Monday, but for now: I really like that the methods are standalone. I guess in the long run we would also try to use `from_graph` for reading SWC files?

Yes, as mentioned in the PR description, reading SWC into a graph is the plan.
Potential use-case of
It took way more time than I would have suspected or liked to spend, and two major iterations, but it's finally mostly there! Both export and import from/to graph are working now, and the SWC pipeline passes the tests too. Some more tests and cleanup remain, but then I think you can do a review @michaeldeistler. A small side benefit of this PR is that it reduces the difference between Jaxley and NEURON slightly (by about 20%).
TLDR: I think I have a working version of the graph pipeline now, but tests are still not passing and I don't quite know why. Everything looks good to me. Help is much appreciated! I have been trying for ages to get the following working:

```python
graph = swc_to_graph(fname)
cell = from_graph(graph, nseg=8, max_branch_len=2000.0)
```

This reads an SWC file into a networkx graph and imports the cell into Jaxley. While this all looks very promising, I have struggled to get the tests passing and I really don't know why: while the traces look very similar, they are all slightly delayed. The MSEs also show this.

Would really appreciate help. I will now be moving on to something else for the moment though, so it is not urgent. Best, Jonas
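Not part of the PR, but to illustrate the symptom described above: two traces that look almost identical except for a small delay still produce a clearly nonzero MSE. A sketch with a sine wave standing in for a voltage trace:

```python
import numpy as np

t = np.linspace(0.0, 2 * np.pi, 1000)
trace = np.sin(t)          # stand-in for a simulated voltage trace
delayed = np.sin(t - 0.5)  # the same trace, shifted by a small delay

mse = np.mean((trace - delayed) ** 2)
# the shapes match almost perfectly, yet the MSE is far from zero
print(f"MSE between trace and delayed copy: {mse:.4f}")
```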
I think this is finally ready for a first round of reviews. This has become quite the mammoth PR, but the functionality it enables is neat imo. For a rundown, see the updated 08_morphologies.ipynb. All tests are passing now, which turned out to be an immense amount of work, but the imported morphologies are similar enough to NEURON both at the compartment level (x, y, z, r, l) and they also simulate correctly. I have essentially cloned the existing tests for this. Notable changes are:

Lemme know your thoughts. Would also be happy to go through this in person.
Just noting that it is a branch from this branch at its current stage.
This has become a marathon PR, due to what it took to get the tests to pass (matching NEURON's and Jaxley's morphologies by coordinates, simulating SWC errors, ... just to name a few). However, I think it is on the home stretch. All tests are passing and I think all the functionality is there. @michaeldeistler I left you a few comments; your feedback would be greatly appreciated. Two things:
This is super awesome! I ran the following test and it passed; could you add it to the `jaxley_identical` tests? I will send you the morphology on Discord.
```python
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp  # needed for the reference trace below

import jaxley as jx
from jaxley.channels import HH, Na, K, Leak
from jaxley_mech.channels.l5pc import *

gt_apical = {}
gt_soma = {}
gt_axon = {}

gt_apical["apical_NaTs2T_gNaTs2T"] = 0.026145
gt_apical["apical_SKv3_1_gSKv3_1"] = 0.004226
gt_apical["apical_M_gM"] = 0.000143

gt_soma["somatic_NaTs2T_gNaTs2T"] = 0.983955
gt_soma["somatic_SKv3_1_gSKv3_1"] = 0.303472
gt_soma["somatic_SKE2_gSKE2"] = 0.008407
gt_soma["somatic_CaPump_gamma"] = 0.000609
gt_soma["somatic_CaPump_decay"] = 210.485291
gt_soma["somatic_CaHVA_gCaHVA"] = 0.000994
gt_soma["somatic_CaLVA_gCaLVA"] = 0.000333

gt_axon["axonal_NaTaT_gNaTaT"] = 3.137968
gt_axon["axonal_KPst_gKPst"] = 0.973538
gt_axon["axonal_KTst_gKTst"] = 0.089259
gt_axon["axonal_SKE2_gSKE2"] = 0.007104
gt_axon["axonal_SKv3_1_gSKv3_1"] = 1.021945
gt_axon["axonal_CaHVA_gCaHVA"] = 0.00099
gt_axon["axonal_CaLVA_gCaLVA"] = 0.008752
gt_axon["axonal_CaPump_gamma"] = 0.00291
gt_axon["axonal_CaPump_decay"] = 287.19873

cell = jx.read_swc("morphologies/bbp_with_axon.swc", ncomp=2)
soma_inds = cell.groups["soma"]
apical_inds = cell.groups["apical"]

########## APICAL ##########
cell.apical.set("capacitance", 2.0)
cell.apical.insert(NaTs2T().change_name("apical_NaTs2T"))
cell.apical.insert(SKv3_1().change_name("apical_SKv3_1"))
cell.apical.insert(M().change_name("apical_M"))
cell.apical.insert(H().change_name("apical_H"))
for c in apical_inds:
    distance = cell.scope("global").comp(c).distance(cell.branch(0).loc(0.0))
    cond = (-0.8696 + 2.087 * np.exp(distance * 0.0031)) * 8e-5
    cell.scope("global").comp(c).set("apical_H_gH", cond)

########## SOMA ##########
cell.soma.insert(NaTs2T().change_name("somatic_NaTs2T"))
cell.soma.insert(SKv3_1().change_name("somatic_SKv3_1"))
cell.soma.insert(SKE2().change_name("somatic_SKE2"))
ca_dynamics = CaNernstReversal()
ca_dynamics.channel_constants["T"] = 307.15
cell.soma.insert(ca_dynamics)
cell.soma.insert(CaPump().change_name("somatic_CaPump"))
cell.soma.insert(CaHVA().change_name("somatic_CaHVA"))
cell.soma.insert(CaLVA().change_name("somatic_CaLVA"))
cell.soma.set("CaCon_i", 5e-05)
cell.soma.set("CaCon_e", 2.0)

########## BASAL ##########
cell.basal.insert(H().change_name("basal_H"))
cell.basal.set("basal_H_gH", 8e-5)

########## AXON ##########
cell.insert(CaNernstReversal())
cell.set("CaCon_i", 5e-05)
cell.set("CaCon_e", 2.0)
cell.axon.insert(NaTaT().change_name("axonal_NaTaT"))
cell.axon.insert(KTst().change_name("axonal_KTst"))
cell.axon.insert(CaPump().change_name("axonal_CaPump"))
cell.axon.insert(SKE2().change_name("axonal_SKE2"))
cell.axon.insert(CaHVA().change_name("axonal_CaHVA"))
cell.axon.insert(KPst().change_name("axonal_KPst"))
cell.axon.insert(SKv3_1().change_name("axonal_SKv3_1"))
cell.axon.insert(CaLVA().change_name("axonal_CaLVA"))

########## WHOLE CELL ##########
cell.insert(Leak())
cell.set("Leak_gLeak", 3e-05)
cell.set("Leak_eLeak", -75.0)
cell.set("axial_resistivity", 100.0)
cell.set("eNa", 50.0)
cell.set("eK", -85.0)
cell.set("v", -65.0)

for key in gt_apical.keys():
    cell.apical.set(key, gt_apical[key])
for key in gt_soma.keys():
    cell.soma.set(key, gt_soma[key])
for key in gt_axon.keys():
    cell.axon.set(key, gt_axon[key])

dt = 0.025
t_max = 100.0
time_vec = np.arange(0, t_max + 2 * dt, dt)

cell.delete_stimuli()
cell.delete_recordings()
i_delay = 10.0
i_dur = 80.0
i_amp = 3.0
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
cell.scope("global").comp(soma_inds[0]).stimulate(current)  # stimulate soma
cell.scope("global").comp(soma_inds[0]).record()

cell.set("v", -65.0)
cell.init_states()
voltages = jx.integrate(cell)

voltages_250130 = jnp.asarray([[-65.        , -66.22422623, -67.23001452, -68.06298803,
                                -68.75766951, -33.91317711, -55.24503749, -46.11452291,
                                -42.18960646, -51.12861864, -43.65442616, -40.62727385,
                                -49.56110473, -43.24030949, -36.71731271, -48.7405489 ,
                                -42.98507829, -34.64282586, -48.24427898, -42.6412365 ,
                                -34.70568206, -47.90643598, -42.15688181, -36.17711814,
                                -47.65564274, -41.52265914, -38.1627371 , -47.44680473,
                                -40.70730741, -40.15298353, -47.25483146, -39.63994798,
                                -41.96818737, -47.06569105, -38.17257448, -43.50053648,
                                -46.87517934, -65.40488865, -69.96981343, -72.24384111,
                                -73.46204372]])

# compare against the stored reference trace
max_error = np.max(np.abs(voltages[:, ::100] - voltages_250130))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
```
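A side note on the subsampling (my own sanity check, assuming the integrator returns `t_max / dt + 2` samples, as the snippet suggests): `[::100]` with `dt = 0.025` keeps one sample every 2.5 ms, which yields exactly the 41 values stored in the reference trace:

```python
dt = 0.025
t_max = 100.0

# assumed output length: one sample per step plus two extra points
n_samples = int(round(t_max / dt)) + 2

# voltages[:, ::100] keeps every 100th sample, i.e. one every 2.5 ms
n_subsampled = len(range(0, n_samples, 100))
print(n_samples, n_subsampled)  # 4002 41
```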
I have to admit that I did not do a detailed review of the rest of the code, but given that I love the API and the tutorial, and all tests (including complicated ones as described above) pass, we should merge this.

One remaining question: do I understand correctly that `read_swc` does the same thing as before? IIUC, we do not yet have any SWC->networkX function, right? One has to go via the Jaxley SWC reader?

Thanks so much!
Good to know it's robust to new morphologies :D I will add this to the test cases.

I can give you a rundown when you're back next week. Before I merge this though, @michaeldeistler could you just briefly go through the points in my own review above? Would appreciate it if you could leave your opinion.

I kept the old pipeline completely intact. There is also a lot of context in the docstrings. :)
-.- forgot to add a final message, so the review was pending the whole time ... here it is

@michaeldeistler marked the test to be skipped for now. Lemme know if you're ok with merging this.
This PR adds support for graph methods. It implements a `to_graph` and a `from_graph` method in `Module`.

Todos:

- [ ] `to_graph`
- [ ] `to_graph`
- [ ] `from_graph`
- [ ] `from_graph`
- [ ] ensure it works for views and modules; with "Make `View` behave more like `Module`" (#351) this should be straightforward
- [ ] `.swc` -> `nx.DiGraph` -> `Cell` functionality

Thoughts I had while coding this up:
- ... `self.nodes` and `self.edges` from it.
- attrs from `module.edges`, `module.branch_edges` and `module.nodes` are stored as node / edge attrs in the `DiGraph`; this could also be used for plotting or in `vis`.
- since everything is stored in `DiGraph.graph`, `DiGraph.nodes` and `DiGraph.edges`, this could also be used to save and share modules.
- for groups in `module.nodes`, not sure if one `group` column that contains lists of groups, or several boolean cols with `group_name`, is the better solution.
- ... once "Make `View` behave more like `Module`" (#351) is merged, since this implements all properties that are attached to the nodes / edges for views of modules.

EDIT:
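The save-and-share idea could work roughly like this (a sketch with plain networkx; the node and edge attribute names are illustrative stand-ins, not the attrs `to_graph` actually attaches):

```python
import json

import networkx as nx

# A toy stand-in for a module graph: nodes carry compartment attrs,
# edges carry synapse attrs (names here are invented for illustration).
g = nx.DiGraph()
g.add_node(0, x=0.0, y=0.0, radius=1.0)
g.add_node(20, x=5.0, y=2.0, radius=0.5)
g.add_edge(0, 20, synapse="IonotropicSynapse")

# Serialize graph + attrs to JSON-compatible data and back.
data = nx.node_link_data(g)
s = json.dumps(data)  # shareable as a plain text file
g2 = nx.node_link_graph(json.loads(s))

print(g2.nodes[0])          # {'x': 0.0, 'y': 0.0, 'radius': 1.0}
print(g2.edges[(0, 20)])    # {'synapse': 'IonotropicSynapse'}
```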